import os
import json
import pickle
from utils import *
import argparse
from rule_config import rule_hyperparams

def is_match(res_line):
    """Tell whether a serialized result record is an exact-match prediction.

    res_line is a JSON string holding at least the keys 'prediction' and
    'ground-truth hole'. Both values are compared after removing trailing
    whitespace only; leading whitespace still counts as a difference.

    Returns True when the two stripped strings are equal, False otherwise.
    """
    record = json.loads(res_line)
    predicted = record['prediction'].rstrip()
    expected = record['ground-truth hole'].rstrip()
    return predicted == expected

def get_hole_identities(capped_hole_filename):
    """Read the hole identities listed one per line in a text file.

    Args:
        capped_hole_filename: path of the file to read.

    Returns:
        A list with one entry per line of the file, in file order, each
        stripped of leading and trailing whitespace. Blank lines are kept
        as empty strings.
    """
    # The context manager closes the handle deterministically instead of
    # leaving it to the garbage collector.
    with open(capped_hole_filename, 'r') as hole_file:
        return [line.strip() for line in hole_file]

def update_hole_rule_mapping(hid, hole_rule_mapping, rule_parts):
  """Append rule_parts to the list stored under hid, creating it if needed.

  The mapping is modified in place and the same object is returned.
  """
  hole_rule_mapping.setdefault(hid, []).append(rule_parts)
  return hole_rule_mapping

def get_hids(lines, hole_identities):
  """Select the result lines whose hole identity is one of hole_identities.

  Each element of lines is a JSON string with a 'hole_identity' key.

  Returns a pair (hids, mod_lines): the identities of the retained lines and
  the retained lines themselves (unparsed), both in input order and aligned
  index by index. Duplicate identities are kept.
  """
  kept = []
  for raw_line in lines:
    hole_id = json.loads(raw_line)['hole_identity']
    if hole_id in hole_identities:
      kept.append((hole_id, raw_line))
  hids = [hole_id for hole_id, _ in kept]
  mod_lines = [raw_line for _, raw_line in kept]
  return hids, mod_lines

def read_result_file(rule_result_file, oracle, hole_identities, hole_rule_mapping, rule_parts):
  """Record rule_parts for every hole that counts as successful for one result file.

  For each identity in hole_identities:
    * if the result file has a line for that hole, the hole is successful
      when that line is an exact match (see is_match); when several lines
      share the identity, only the first one is consulted;
    * otherwise the hole falls back to the oracle entry and is successful
      when oracle[hid]['com'][62] == 1. Holes absent from the oracle are
      treated as unsuccessful.

  Args:
    rule_result_file: path of a file with one JSON result record per line.
    oracle: mapping from hole identity to an entry indexed as ['com'][62].
    hole_identities: sequence of hole identities to evaluate.
    hole_rule_mapping: dict from hole identity to a list of rule_parts;
      updated in place.
    rule_parts: value appended under each successful hole.

  Returns:
    The same hole_rule_mapping object, after the updates.
  """
  # Close the file as soon as its lines are in memory.
  with open(rule_result_file, 'r') as result_file:
    rule_lines = result_file.readlines()
  rule_hids, mod_rule_lines = get_hids(rule_lines, hole_identities)

  # Index the first line seen for each hole once, instead of a linear
  # `in` + `.index()` scan per hole. setdefault keeps the earliest line,
  # which is what list.index() would have selected.
  first_line_by_hid = {}
  for rule_hid, rule_line in zip(rule_hids, mod_rule_lines):
    first_line_by_hid.setdefault(rule_hid, rule_line)

  for hid in hole_identities:
    if hid in first_line_by_hid:
      # use rule result
      if is_match(first_line_by_hid[hid]):
        hole_rule_mapping = update_hole_rule_mapping(hid, hole_rule_mapping, rule_parts)
    else:
      # use codex result
      # NOTE(review): position 62 of the 'com' entry is read as the
      # fallback outcome (1 meaning success) -- confirm against whatever
      # produces the oracle file.
      codex_match = oracle[hid]['com'][62] if hid in oracle else 0
      if codex_match == 1:
        hole_rule_mapping = update_hole_rule_mapping(hid, hole_rule_mapping, rule_parts)
  return hole_rule_mapping

def get_results(base_result_dir, context_location):
  """List the non-empty files directly inside base_result_dir/context_location.

  Only the top level of the directory is inspected. Returns the full paths
  of the files whose size is greater than zero; the list is empty when the
  directory has no files or cannot be walked.
  """
  context_result_dir = os.path.join(base_result_dir, context_location)
  # Only the first os.walk() entry is needed; the default covers the case
  # where the walk yields nothing.
  _, _, file_names = next(os.walk(context_result_dir), (None, None, []))
  candidate_paths = []
  for file_name in file_names:
    if file_name:
      candidate_paths.append(os.path.join(context_result_dir, file_name))
  non_empty_paths = []
  for path in candidate_paths:
    if os.path.getsize(path) > 0:
      non_empty_paths.append(path)
  return non_empty_paths

def get_all_hole_rule_mapping(base_result_dir, hole_identities, context_location, oracle):
  """Build the hole -> rule mapping from every result file of one context location.

  Every non-empty result file under base_result_dir/context_location is fed
  to read_result_file, all of them updating one shared dictionary. The rule
  recorded for each file is the fixed triple
  (context_location, 'random', '0.5').

  Returns the accumulated dictionary, which is empty when there are no
  result files.
  """
  # NOTE(review): 'random' and '0.5' are hard-coded here; presumably the
  # other two components of a rule description -- confirm with the caller.
  rule_parts = (context_location, 'random', '0.5')
  mapping = {}
  for result_path in get_results(base_result_dir, context_location):
    mapping = read_result_file(result_path, oracle, hole_identities, mapping, rule_parts)
  return mapping


def setup_args():
  """
  Description: Takes in the command-line arguments from user

  Returns the argparse namespace with the string attributes base_dir,
  data_split, proj_name, results_dir and context_location, each falling
  back to its default when the option is not given.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument("--base_dir", type=str, default='rule_classifier_data', help="base directory for the data")
  parser.add_argument("--data_split", type=str, default='test', help="data split to store the data")
  parser.add_argument("--proj_name", type=str, default='dovetaildb', help="name of the input repo")
  # The two help strings below used to repeat "name of the input repo";
  # they now describe how the script actually uses each option.
  parser.add_argument("--results_dir", type=str, default='results', help="root directory containing the result files")
  parser.add_argument("--context_location", type=str, default='random_file_NN',
                      help="sub-directory of the project results whose result files are evaluated")
  return parser.parse_args()

if __name__ == '__main__':
  # Prints "<proj_name>, <percentage>", where the percentage is the share of
  # capped holes that ended up with at least one entry in the hole -> rule
  # mapping for the chosen context location.

  args = setup_args()
  proj_data_dir = os.path.join(args.base_dir, args.data_split, args.proj_name)
  capped_hole_filename = os.path.join(proj_data_dir, 'capped_holes_10000')
  hole_identities = get_hole_identities(capped_hole_filename)
  base_result_dir = os.path.join(args.results_dir, args.base_dir, args.data_split, args.proj_name)
  # NOTE: pickle.load can execute arbitrary code while unpickling; only run
  # this script on oracle files that come from a trusted source.
  with open(os.path.join(proj_data_dir, 'capped_oracle_10000'), 'rb') as oracle_file:
    oracle = pickle.load(oracle_file)
  successful_holes = get_all_hole_rule_mapping(base_result_dir, hole_identities, args.context_location, oracle)
  if hole_identities:
    success_rate = float(len(successful_holes) * 100 / len(hole_identities))
  else:
    # An empty capped-holes file would otherwise raise ZeroDivisionError.
    success_rate = 0.0
  print(args.proj_name + ", " + str(success_rate))
